{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![MLU Logo](../data/MLU_Logo.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# <a name=\"0\">Machine Learning Accelerator - Natural Language Processing - Lecture 3</a>\n",
    "\n",
    "## Recurrent Neural Networks (RNNs) for the Product Review Problem - Classify Product Reviews as Positive or Not\n",
    "\n",
    "In this exercise, we will learn how to use Recurrent Neural Networks. \n",
    "\n",
    "We will follow these steps:\n",
    "1. <a href=\"#1\">Reading the dataset</a>\n",
    "2. <a href=\"#2\">Exploratory data analysis</a>\n",
    "3. <a href=\"#3\">Train-validation dataset split</a>\n",
    "4. <a href=\"#4\">Text processing and Transformation</a>\n",
    "5. <a href=\"#5\">Generating data batch and iterator</a>\n",
    "6. <a href=\"#6\">Using pre-trained GloVe Word Embeddings</a>\n",
    "7. <a href=\"#7\">Setting Hyperparameters and Bulding the Network</a>\n",
    "8. <a href=\"#8\">Training the Network</a>\n",
    "9. <a href=\"#9\">Test the classifier on the validation data</a>\n",
    "10. <a href=\"#10\">Improvement ideas</a>\n",
    "\n",
    "Overall dataset schema:\n",
    "* __reviewText:__ Text of the review\n",
    "* __summary:__ Summary of the review\n",
    "* __verified:__ Whether the purchase was verified (True or False)\n",
    "* __time:__ UNIX timestamp for the review\n",
    "* __log_votes:__ Logarithm-adjusted votes log(1+votes)\n",
    "* __isPositive:__ Whether the review is positive or negative (1 or 0)\n",
    "\n",
    "__Important note:__ One big distinction betweeen the regular neural networks and RNNs is that RNNs work with sequential data. In our case, RNNs will help us with the text field. If we also want to consider other fields such as time, log_votes, verified, etc. , we need to use the regular neural networks with the RNN network."
   ]
  },
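  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One possible way to combine an RNN with a regular network is sketched below: the final RNN hidden state is concatenated with the numeric fields and passed to a small feed-forward head. This is only an illustrative sketch, not the model built in this notebook; `HybridNet`, `num_features`, and all toy sizes are hypothetical names and values chosen for the example.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "class HybridNet(nn.Module):\n",
    "    # Hypothetical model: RNN over token ids + dense layers over numeric fields\n",
    "    def __init__(self, vocab_size, embed_size, hidden_size, num_features):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
    "        self.rnn = nn.RNN(embed_size, hidden_size, batch_first=True)\n",
    "        # The head sees the last hidden state plus the numeric features\n",
    "        self.head = nn.Sequential(\n",
    "            nn.Linear(hidden_size + num_features, 16),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(16, 1),\n",
    "        )\n",
    "\n",
    "    def forward(self, token_ids, features):\n",
    "        embedded = self.embedding(token_ids)      # (batch, seq_len, embed_size)\n",
    "        _, last_hidden = self.rnn(embedded)       # (num_layers, batch, hidden_size)\n",
    "        combined = torch.cat([last_hidden[-1], features], dim=1)\n",
    "        return self.head(combined)                # (batch, 1) raw scores\n",
    "\n",
    "# Toy inputs: 4 reviews of 10 token ids, 3 numeric fields each\n",
    "toy_tokens = torch.randint(0, 100, (4, 10))\n",
    "toy_features = torch.rand(4, 3)\n",
    "toy_logits = HybridNet(100, 8, 5, 3)(toy_tokens, toy_features)\n",
    "print(toy_logits.shape)\n",
    "```\n",
    "\n",
    "The head returns raw scores, so a sigmoid (or a loss that applies it internally) would still be needed to obtain probabilities."
   ]
  },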
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we need to install additional required libraries (for more details see this page [here](https://github.com/pytorch/text#installation))."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install -q -r ../requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-09T05:02:48.342987Z",
     "start_time": "2021-01-09T05:02:47.164823Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "import torch, torchtext\n",
    "import pandas as pd\n",
    "from collections import Counter\n",
    "from torch import nn, optim\n",
    "from torch.nn import BCEWithLogitsLoss\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "from torchtext.data.utils import get_tokenizer\n",
    "from torchtext.vocab import GloVe\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import confusion_matrix, classification_report, accuracy_score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. <a name=\"1\">Reading the dataset</a>\n",
    "(<a href=\"#0\">Go to top</a>)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's read the dataset below and fill-in the reviewText field. We will use this field as input to our ML model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-09T05:02:48.995226Z",
     "start_time": "2021-01-09T05:02:48.344888Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "df = pd.read_csv('../data/examples/AMAZON-REVIEW-DATA-CLASSIFICATION.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's look at the first five rows in the dataset. As you can see the __log_votes__ field is numeric. That's why we will build a regression model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-09T05:02:49.015545Z",
     "start_time": "2021-01-09T05:02:48.997444Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>reviewText</th>\n",
       "      <th>summary</th>\n",
       "      <th>verified</th>\n",
       "      <th>time</th>\n",
       "      <th>log_votes</th>\n",
       "      <th>isPositive</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>PURCHASED FOR YOUNGSTER WHO\\nINHERITED MY \"TOO...</td>\n",
       "      <td>IDEAL FOR BEGINNER!</td>\n",
       "      <td>True</td>\n",
       "      <td>1361836800</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>unable to open or use</td>\n",
       "      <td>Two Stars</td>\n",
       "      <td>True</td>\n",
       "      <td>1452643200</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Waste of money!!! It wouldn't load to my system.</td>\n",
       "      <td>Dont buy it!</td>\n",
       "      <td>True</td>\n",
       "      <td>1433289600</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>I attempted to install this OS on two differen...</td>\n",
       "      <td>I attempted to install this OS on two differen...</td>\n",
       "      <td>True</td>\n",
       "      <td>1518912000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>I've spent 14 fruitless hours over the past tw...</td>\n",
       "      <td>Do NOT Download.</td>\n",
       "      <td>True</td>\n",
       "      <td>1441929600</td>\n",
       "      <td>1.098612</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                          reviewText  \\\n",
       "0  PURCHASED FOR YOUNGSTER WHO\\nINHERITED MY \"TOO...   \n",
       "1                              unable to open or use   \n",
       "2   Waste of money!!! It wouldn't load to my system.   \n",
       "3  I attempted to install this OS on two differen...   \n",
       "4  I've spent 14 fruitless hours over the past tw...   \n",
       "\n",
       "                                             summary  verified        time  \\\n",
       "0                                IDEAL FOR BEGINNER!      True  1361836800   \n",
       "1                                          Two Stars      True  1452643200   \n",
       "2                                       Dont buy it!      True  1433289600   \n",
       "3  I attempted to install this OS on two differen...      True  1518912000   \n",
       "4                                   Do NOT Download.      True  1441929600   \n",
       "\n",
       "   log_votes  isPositive  \n",
       "0   0.000000         1.0  \n",
       "1   0.000000         0.0  \n",
       "2   0.000000         0.0  \n",
       "3   0.000000         0.0  \n",
       "4   1.098612         0.0  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. <a name=\"2\">Exploratory Data Analysis</a>\n",
    "(<a href=\"#0\">Go to top</a>)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's look at the range and distribution of log_votes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-09T05:02:49.024615Z",
     "start_time": "2021-01-09T05:02:49.017492Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "isPositive\n",
       "1.0    43692\n",
       "0.0    26308\n",
       "Name: count, dtype: int64"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df[\"isPositive\"].value_counts()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can check the number of missing values for each columm below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-09T05:02:49.040120Z",
     "start_time": "2021-01-09T05:02:49.026288Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "reviewText    12\n",
      "summary       15\n",
      "verified       0\n",
      "time           0\n",
      "log_votes      0\n",
      "isPositive     0\n",
      "dtype: int64\n"
     ]
    }
   ],
   "source": [
    "print(df.isna().sum())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have missing values in our text fields. We can the rows with __reviewText__ field missing as we will use that field with the model aftwerwards."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "df = df.dropna(subset=['reviewText'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's check again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "reviewText     0\n",
      "summary       13\n",
      "verified       0\n",
      "time           0\n",
      "log_votes      0\n",
      "isPositive     0\n",
      "dtype: int64\n"
     ]
    }
   ],
   "source": [
    "print(df.isna().sum())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. <a name=\"3\">Train-validation split</a>\n",
    "(<a href=\"#0\">Go to top</a>)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's split the dataset into training and validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-09T05:02:49.098503Z",
     "start_time": "2021-01-09T05:02:49.041948Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# This separates 10% of the entire dataset into validation dataset.\n",
    "train_text, val_text, train_label, val_label = \\\n",
    "    train_test_split(df[\"reviewText\"].tolist(),\n",
    "                     df[\"isPositive\"].tolist(),\n",
    "                     test_size=0.10,\n",
    "                     shuffle=True,\n",
    "                     random_state=324)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. <a name=\"4\">Text processing and Transformation</a>\n",
    "(<a href=\"#0\">Go to top</a>)\n",
    "\n",
    "We will apply the following processes here:\n",
    "1. Creating a vocabulary\n",
    "2. Text transformation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "__1. Creating a vocabulary:__ \n",
    "\n",
    "We will create a vocabulary with the tokens from the text data. We use a simple english tokenizer and use these tokens to create our vocabulary. In this vocabulary, tokens will map to unique ids, such as \"car\"->32, \"house\"->651, etc. "
   ]
  },
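  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea concrete, here is a minimal pure-Python sketch of such a mapping on a toy corpus. It does not use torchtext; `toy_corpus`, `token_to_id`, and `lookup` are hypothetical names for this illustration, and lowercase whitespace splitting stands in for the real tokenizer.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "toy_corpus = ['the car is fast', 'the house is big', 'the car is red']\n",
    "\n",
    "# Count tokens (whitespace split as a stand-in tokenizer)\n",
    "toy_counter = Counter()\n",
    "for line in toy_corpus:\n",
    "    toy_counter.update(line.lower().split())\n",
    "\n",
    "# Reserve id 0 for unknown words and id 1 for padding\n",
    "token_to_id = {'<unk>': 0, '<pad>': 1}\n",
    "toy_min_freq = 2\n",
    "for token, count in toy_counter.items():\n",
    "    if count >= toy_min_freq:\n",
    "        token_to_id[token] = len(token_to_id)\n",
    "\n",
    "def lookup(token):\n",
    "    # Rare or unseen tokens fall back to the <unk> id\n",
    "    return token_to_id.get(token, token_to_id['<unk>'])\n",
    "\n",
    "print(token_to_id)\n",
    "print(lookup('car'), lookup('house'))\n",
    "```\n",
    "\n",
    "Here 'house' appears only once, so it is left out of the vocabulary and is looked up as the unknown token."
   ]
  },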
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "tokenizer = get_tokenizer(\"basic_english\")\n",
    "counter = Counter()\n",
    "for line in train_text:\n",
    "    counter.update(tokenizer(line))\n",
    "    \n",
    "# Create a vocabulary with words seen at least 5 (min_freq) times\n",
    "vocab = torchtext.vocab.vocab(counter, min_freq=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-09T05:02:50.025716Z",
     "start_time": "2021-01-09T05:02:49.274150Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Add the unknown token\n",
    "# and use it by default for unknown words\n",
    "unk_token = '<unk>'\n",
    "vocab.insert_token(unk_token, 0)\n",
    "vocab.set_default_index(0)\n",
    "\n",
    "# Add the pad token\n",
    "pad_token = '<pad>'\n",
    "vocab.insert_token(pad_token, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here are some examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-09T05:03:26.376574Z",
     "start_time": "2021-01-09T05:02:50.027397Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'home' -> 559\n",
      "'wash' -> 7002\n",
      "'fhshbasdhb' -> 0\n"
     ]
    }
   ],
   "source": [
    "print(f\"'home' -> {vocab['home']}\")\n",
    "print(f\"'wash' -> {vocab['wash']}\")\n",
    "# unknown word (assume from test set)\n",
    "print(f\"'fhshbasdhb' -> {vocab['fhshbasdhb']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-09T05:04:29.369651Z",
     "start_time": "2021-01-09T05:03:26.378501Z"
    }
   },
   "source": [
    "__2. Text transformation:__ \n",
    "\n",
    "We will use the vocabulary and map tokens in the text to unique ids of the tokens. For example: `[\"this\", \"is\", \"a\", \"sentence\"] -> [14, 12, 9, 2066]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Let's create a mapper to transform our text data\n",
    "text_transform_pipeline = lambda x: [vocab[token] for token in tokenizer(x)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-09T05:04:29.374706Z",
     "start_time": "2021-01-09T05:04:29.371615Z"
    }
   },
   "source": [
    "Let's see some text before and after transformation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Before transform:\tCrazy Machines should be relevant to anyone age 8 and up. However, younger gamers may need a little help from their parents when getting started, as well as with some of the more obtuse puzzle solving scenarios built into the game. The marvelous thing about Crazy Machines is its design-your-own experiment setting that will allow you free-form control over tons and tons of the game's many wacky gadgets that can be combined in all sorts of ways to accomplish literally any challenge you can imagine.\n",
      "\n",
      "The game's pre-designed puzzles will challenge most all players, so very young players will also likely need parental help with the challenges that are a bit more complicated.\n",
      "\n",
      "Game fans who remember The Incredible Machine and its successors will most certainly be very pleased to see how this new Crazy Machines approach to this unique game idea plays out. What made The Incredible Machine so much fun and successful was the tremendous variety of gizmos and gadgets you could combine in weird Rube Goldberg concoctions to get something done. Now, Crazy Machines advances this idea still further, and succeeds with its iteration of this idea that was so solidly established by The Incredible Machine.\n",
      "\n",
      "Kudos to the game designer studio team for taking on this challenge and providing for gamers a new approach to the whimsy and challenging fun that we loved so much in The Incredible Machine.\n",
      "After transform:\t[1032, 1033, 394, 142, 1034, 25, 856, 1035, 1036, 13, 218, 18, 461, 27, 1037, 1038, 237, 126, 52, 376, 155, 157, 691, 1039, 204, 644, 300, 27, 32, 213, 32, 6, 277, 64, 8, 77, 1040, 1041, 1042, 1043, 710, 402, 8, 803, 18, 8, 1044, 230, 153, 1032, 1033, 11, 603, 0, 1045, 1020, 139, 177, 469, 154, 0, 1046, 525, 1047, 13, 1047, 64, 8, 803, 15, 304, 232, 1048, 1049, 139, 111, 142, 1050, 120, 236, 1051, 64, 1052, 25, 1053, 1054, 573, 1055, 154, 111, 1056, 18, 8, 803, 15, 304, 1057, 1058, 177, 1055, 240, 236, 1059, 27, 48, 458, 1060, 1059, 177, 369, 1061, 126, 1062, 155, 6, 8, 1063, 139, 95, 52, 497, 77, 1064, 18, 803, 1065, 479, 1066, 8, 1067, 845, 13, 603, 0, 177, 240, 1068, 142, 458, 769, 25, 371, 309, 169, 384, 1032, 1033, 1069, 25, 169, 1070, 803, 1071, 1072, 114, 18, 175, 719, 8, 1067, 845, 48, 166, 743, 13, 1073, 91, 8, 1074, 1075, 64, 0, 13, 1049, 154, 433, 1076, 120, 1077, 0, 0, 0, 25, 172, 792, 1078, 18, 146, 27, 1032, 1033, 1079, 169, 1071, 96, 1080, 27, 13, 1081, 6, 603, 1082, 64, 169, 1071, 139, 91, 48, 1083, 1084, 625, 8, 1067, 845, 18, 1085, 25, 8, 803, 1086, 1087, 1088, 39, 1089, 99, 169, 1055, 13, 1090, 39, 1038, 52, 384, 1069, 25, 8, 0, 13, 1091, 743, 139, 43, 1092, 48, 166, 120, 8, 1067, 845, 18]\n"
     ]
    }
   ],
   "source": [
    "print(f\"Before transform:\\t{train_text[37]}\")\n",
    "print(f\"After transform:\\t{text_transform_pipeline(train_text[37])}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's create a function for this. In this function, we transform and pad (if necessary) our text data. We cut the series of words at the point where it reaches a certain lenght (we used `max_len=50` here). If the text is shorter than max_len, we `pad 1s` to the end (corresponding to the pad token)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def transformText(text_list, max_len):\n",
    "    # Transform the text\n",
    "    transformed_data = [text_transform_pipeline(text)[:max_len] for text in text_list]\n",
    "\n",
    "    # Pad zeros if the text is shoter than max_len\n",
    "    for data in transformed_data:\n",
    "        data[len(data) : max_len] = np.ones(max_len - len(data))\n",
    "\n",
    "    return torch.tensor(transformed_data, dtype=torch.int64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"I had just purchased a new laptop, which had to be returned to manufacturer to replace webcam. This is after I had spent 5 days loading it with all of the new software and personal data. Faced with the dilemma of having to repeat this 5 day process or the security risk of leaving my data on the SSD disk, I just backup'd the disk and restored the laptop disk to its factory condition. After the laptop was repaired, it took only a few minutes to recovery everything from my backup which had been stored on an external drive. FYI, Windows 10 recovery is a complete non-starter for this type of backup.\""
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_text[129]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Text: [\"I purchased the Starter Edition, not realizing that it wouldn't import data from a previous version, which is what I needed.  That is the only thing that I needed from Quicken Deluxe that the starter edition didn't have.  Quicken Deluxe has many features which I don't use at all and which I may never need.\\nThe starter edition is probably the most that many people need when they first start out, but they need to know that they will be unable to buy the next upgrade without jumping up to the Deluxe edition.\", 'This gel bed  was comfortable. I felt cool and not hot at night.']\n",
      "\n",
      "Num sentences: 2\n",
      "\n",
      "Transformed text: \n",
      "tensor([[ 14,  35,   8, 221, 188,  27, 222, 223, 139,  20, 224,  15, 112, 225,\n",
      "         226, 157,  52, 227, 228,  27,  51,  11, 175,  14, 229,  18, 139,  11,\n",
      "           8, 148, 230, 139,  14, 229, 157,   4, 231, 139,   8, 221, 188, 125,\n",
      "          15, 112,  58,  18,   4, 231, 150, 232],\n",
      "        [169,   0, 250,  91, 251,  18,  14, 252, 253,  13, 222, 254, 235, 255,\n",
      "          18,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,\n",
      "           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,\n",
      "           1,   1,   1,   1,   1,   1,   1,   1]])\n",
      "\n",
      "Shape of transformed text: torch.Size([2, 50])\n"
     ]
    }
   ],
   "source": [
    "text = train_text[5:7]\n",
    "print(f\"Text: {text}\\n\")\n",
    "print(f\"Num sentences: {len(text)}\\n\")\n",
    "tt = transformText(text, max_len=50)\n",
    "print(f\"Transformed text: \\n{tt}\\n\")\n",
    "print(f\"Shape of transformed text: {tt.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. <a name=\"5\">Generating data batch and iterator</a>\n",
    "(<a href=\"#0\">Go to top</a>)\n",
    "\n",
    "Let's use the transformText() function and create the data loaders. Here, we use __max_len=100__ to consider the first 100 words in the text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-09T05:04:29.864398Z",
     "start_time": "2021-01-09T05:04:29.376025Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "max_len = 100\n",
    "batch_size = 16\n",
    "\n",
    "# Pass transformed and padded data to dataset\n",
    "# Create data loaders\n",
    "train_dataset = TensorDataset(\n",
    "    transformText(train_text, max_len), torch.tensor(train_label)\n",
    ")\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size)\n",
    "\n",
    "val_dataset = TensorDataset(transformText(val_text, max_len), torch.tensor(val_label))\n",
    "val_loader = DataLoader(val_dataset, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. <a name=\"6\">Using pre-trained GloVe Word Embeddings</a>\n",
    "(<a href=\"#0\">Go to top</a>)\n",
    "\n",
    "In this example, we will use GloVe word vectors. `name='6B'` `dim=300` gives us 6 billion words/phrases vectors. Each word vector has 300 numbers in it. The following code shows how to get the word vectors and create an embedding matrix from them. We will connect our vocabulary indexes to the GloVe embedding with the `get_vecs_by_tokens()` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-09T05:04:29.868989Z",
     "start_time": "2021-01-09T05:04:29.866241Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      ".vector_cache/glove.6B.zip: 862MB [03:34, 4.02MB/s]                               \n",
      "100%|█████████▉| 399999/400000 [00:53<00:00, 7467.97it/s]\n"
     ]
    }
   ],
   "source": [
    "glove = GloVe(name=\"6B\", dim=300)\n",
    "embedding_matrix = glove.get_vecs_by_tokens(vocab.get_itos())"
   ]
  },
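  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea behind the embedding matrix can be sketched with toy data, without downloading GloVe: row `i` of the matrix holds the vector of the token with id `i`, and tokens without a pre-trained vector get a row of zeros (to our understanding, this is also what `get_vecs_by_tokens()` returns by default for tokens GloVe does not know). `toy_vectors` and `toy_itos` are made-up values for illustration.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# Made-up 3-dimensional 'pre-trained' vectors\n",
    "toy_vectors = {'good': [0.1, 0.2, 0.3], 'bad': [-0.1, -0.2, -0.3]}\n",
    "# Index-to-token list, in the same order as the vocabulary ids\n",
    "toy_itos = ['<unk>', '<pad>', 'good', 'bad', 'zzzz']\n",
    "\n",
    "toy_dim = 3\n",
    "toy_matrix = torch.zeros(len(toy_itos), toy_dim)\n",
    "for idx, token in enumerate(toy_itos):\n",
    "    if token in toy_vectors:\n",
    "        toy_matrix[idx] = torch.tensor(toy_vectors[token])\n",
    "\n",
    "# freeze=True keeps the vectors fixed during training\n",
    "toy_embedding = nn.Embedding.from_pretrained(toy_matrix, freeze=True)\n",
    "looked_up = toy_embedding(torch.tensor([2, 4]))\n",
    "print(looked_up)\n",
    "```\n",
    "\n",
    "The first row of the result is the vector for 'good'; the second is all zeros because 'zzzz' has no pre-trained vector."
   ]
  },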
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. <a name=\"7\">Setting Hyperparameters and Bulding the Network</a>\n",
    "(<a href=\"#0\">Go to top</a>)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will set our parameters like below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-09T05:04:29.880059Z",
     "start_time": "2021-01-09T05:04:29.871107Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Size of the state vectors\n",
    "hidden_size = 8\n",
    "\n",
    "# General NN training parameters\n",
    "learning_rate = 0.001\n",
    "epochs = 25\n",
    "\n",
    "# Embedding vector and vocabulary sizes\n",
    "embed_size = 300  # glove.6B.300d.txt\n",
    "vocab_size = len(vocab.get_itos())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need to put our data into correct format before the process.\n",
    "Our model is made of these layers:\n",
    "* Embedding layer: This is where our words/tokens are mapped to word vectors.\n",
    "* RNN layer: We are using a simple RNN model. We stack 2 RNN layers in this example. More details about the RNN are available [here](https://pytorch.org/docs/stable/generated/torch.nn.RNN.html).\n",
    "* Linear layer: A linear layer with a single neuron is used to output the `isPositive` prediction."
   ]
  },
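  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick shape walk-through of these three layers on random token ids may help. This sketch assumes the RNN is created with `batch_first=True`, so that inputs are read as (batch, sequence, features); all sizes are toy values chosen for illustration.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "toy_batch, toy_len, toy_vocab, toy_embed, toy_hidden = 4, 6, 50, 10, 8\n",
    "\n",
    "toy_embedding_layer = nn.Embedding(toy_vocab, toy_embed)\n",
    "toy_rnn = nn.RNN(toy_embed, toy_hidden, num_layers=2, batch_first=True)\n",
    "toy_linear = nn.Linear(toy_hidden * toy_len, 1)\n",
    "\n",
    "token_ids = torch.randint(0, toy_vocab, (toy_batch, toy_len))\n",
    "embedded = toy_embedding_layer(token_ids)     # (4, 6, 10)\n",
    "outputs, last_hidden = toy_rnn(embedded)      # (4, 6, 8) and (2, 4, 8)\n",
    "flat = outputs.reshape(toy_batch, -1)         # (4, 48): all time steps side by side\n",
    "scores = toy_linear(flat)                     # (4, 1): one score per review\n",
    "print(embedded.shape, outputs.shape, flat.shape, scores.shape)\n",
    "```\n",
    "\n",
    "Because the linear layer takes the outputs of all time steps, its input size depends on the sequence length, which is why every review has to be padded or cut to the same length."
   ]
  },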
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-09T05:04:29.892791Z",
     "start_time": "2021-01-09T05:04:29.881808Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
    "        self.rnn = nn.RNN(\n",
    "            embed_size, hidden_size, num_layers=num_layers\n",
    "        )\n",
    "\n",
    "        self.linear = nn.Linear(hidden_size*max_len, 1)\n",
    "        self.act = nn.Sigmoid()\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        embeddings = self.embedding(inputs)\n",
    "        # Call RNN layer\n",
    "        outputs, _ = self.rnn(embeddings)\n",
    "        # Use the output of each time step\n",
    "        # Send it all together to the linear layer\n",
    "        outs = self.linear(outputs.reshape(outputs.shape[0], -1))\n",
    "        return self.act(outs)\n",
    "    \n",
    "model = Net(vocab_size, embed_size, hidden_size, num_layers=2)\n",
    "\n",
    "# Initialize the weights\n",
    "def init_weights(m):\n",
    "    if type(m) == nn.Linear:\n",
    "        nn.init.xavier_uniform_(m.weight)\n",
    "    if type(m) == nn.RNN:\n",
    "        for param in m._flat_weights_names:\n",
    "            if \"weight\" in param:\n",
    "                nn.init.xavier_uniform_(m._parameters[param])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's initialize this network. Then, we will need to make the embedding layer use our GloVe word vectors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-09T05:04:29.902048Z",
     "start_time": "2021-01-09T05:04:29.899284Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# We set the embedding layer's parameters from GloVe\n",
    "model.embedding.weight.data.copy_(embedding_matrix)\n",
    "# We won't change/train the embedding layer\n",
    "model.embedding.weight.requires_grad = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. <a name=\"8\">Training the Network</a>\n",
    "(<a href=\"#0\">Go to top</a>)\n",
    "\n",
    "Now, it is time to start our training. We define the loss function and training algorithm first. Then, training starts!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-09T05:04:29.906415Z",
     "start_time": "2021-01-09T05:04:29.903716Z"
    }
   },
   "source": [
    "We will define the trainer and loss function below. \n",
    "\n",
    "__Binary cross-entropy loss__ is used as this is a binary classification problem.\n",
    "\n",
    "$$\n",
    "\\mathrm{BinaryCrossEntropyLoss} = -\\sum_{examples}{(y\\log(p) + (1 - y)\\log(1 - p))}\n",
    "$$"
   ]
  },
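  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small worked example of this formula, the sketch below computes the summed loss by hand for three made-up predictions and compares it with `nn.BCELoss(reduction='sum')`. The probabilities and labels are invented for illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "probs = torch.tensor([0.9, 0.2, 0.6])    # predicted P(isPositive = 1)\n",
    "labels = torch.tensor([1.0, 0.0, 1.0])   # true classes\n",
    "\n",
    "# Apply the formula term by term\n",
    "manual_loss = -sum(\n",
    "    y * math.log(p) + (1 - y) * math.log(1 - p)\n",
    "    for p, y in zip(probs.tolist(), labels.tolist())\n",
    ")\n",
    "torch_loss = nn.BCELoss(reduction='sum')(probs, labels).item()\n",
    "print(manual_loss, torch_loss)\n",
    "```\n",
    "\n",
    "Both values should agree up to floating-point precision. The least confident correct prediction (0.6 for a positive review) contributes the largest term."
   ]
  },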
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Setting our trainer\n",
    "trainer = torch.optim.SGD(model.parameters(), lr=learning_rate)\n",
    "\n",
    "# We will use Binary Cross-entropy loss\n",
    "# reduction=\"sum\" sums the losses for given output and target\n",
    "cross_ent_loss = nn.BCELoss(reduction=\"sum\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-09T05:06:35.434926Z",
     "start_time": "2021-01-09T05:04:29.908071Z"
    }
   },
   "source": [
    "Now, it is time to start the training process. We will print the Binary cross-entropy loss loss after each epoch."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see some validation results below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-09T05:06:36.046946Z",
     "start_time": "2021-01-09T05:06:35.436633Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0. Train_loss 0.5749193862184427. Val_loss 0.4952433282331937. Seconds 36.04134702682495\n",
      "Epoch 1. Train_loss 0.48871823963112654. Val_loss 0.4601000900148647. Seconds 33.45679497718811\n",
      "Epoch 2. Train_loss 0.45644251649493167. Val_loss 0.44080355651447783. Seconds 34.86459231376648\n",
      "Epoch 3. Train_loss 0.43646846645514925. Val_loss 0.43376576227705077. Seconds 37.549219369888306\n",
      "Epoch 4. Train_loss 0.4238422774737129. Val_loss 0.43014194645495357. Seconds 38.484665393829346\n",
      "Epoch 5. Train_loss 0.41523077599075586. Val_loss 0.4284431818196187. Seconds 31.103814363479614\n",
      "Epoch 6. Train_loss 0.40890235627330973. Val_loss 0.427571468531089. Seconds 33.900928020477295\n",
      "Epoch 7. Train_loss 0.4039536820081941. Val_loss 0.42696130284242484. Seconds 33.89893341064453\n",
      "Epoch 8. Train_loss 0.39989804680101526. Val_loss 0.42630732740772026. Seconds 30.393486738204956\n",
      "Epoch 9. Train_loss 0.3964855804182794. Val_loss 0.4253865909467409. Seconds 34.42492699623108\n",
      "Epoch 10. Train_loss 0.39357005691567687. Val_loss 0.4241878995726425. Seconds 32.364160776138306\n",
      "Epoch 11. Train_loss 0.39101992983293365. Val_loss 0.42278922661864293. Seconds 33.451287031173706\n",
      "Epoch 12. Train_loss 0.38874290612647466. Val_loss 0.42128112265375245. Seconds 32.36166453361511\n",
      "Epoch 13. Train_loss 0.38667908396007244. Val_loss 0.41981661798341185. Seconds 32.5251944065094\n",
      "Epoch 14. Train_loss 0.3847890744829816. Val_loss 0.418469389926912. Seconds 35.70081067085266\n",
      "Epoch 15. Train_loss 0.3830510070554115. Val_loss 0.4172244998178647. Seconds 33.57869815826416\n",
      "Epoch 16. Train_loss 0.381443474713997. Val_loss 0.41606583852804735. Seconds 33.65086340904236\n",
      "Epoch 17. Train_loss 0.3799418716660873. Val_loss 0.4149827193934537. Seconds 33.91373872756958\n",
      "Epoch 18. Train_loss 0.3785265276843438. Val_loss 0.4139553481365105. Seconds 35.87880253791809\n",
      "Epoch 19. Train_loss 0.3771835694714041. Val_loss 0.41296302904348947. Seconds 34.16612648963928\n",
      "Epoch 20. Train_loss 0.3759035045440233. Val_loss 0.41199824469040797. Seconds 34.193938970565796\n",
      "Epoch 21. Train_loss 0.3746786926965275. Val_loss 0.4110676825327573. Seconds 33.171313524246216\n",
      "Epoch 22. Train_loss 0.3735027448530206. Val_loss 0.41018177632145036. Seconds 32.207141637802124\n",
      "Epoch 23. Train_loss 0.3723720128504165. Val_loss 0.40933081309681946. Seconds 35.42681431770325\n",
      "Epoch 24. Train_loss 0.37128574860400865. Val_loss 0.4085043627699029. Seconds 33.7281653881073\n"
     ]
    }
   ],
   "source": [
    "# Get the compute device\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "model.apply(init_weights)\n",
    "model.to(device)\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    start = time.time()\n",
    "    training_loss = 0\n",
    "    val_loss = 0\n",
    "    # Training loop, train the network\n",
    "    for data, target in train_loader:\n",
    "        trainer.zero_grad()\n",
    "        data = data.to(device)\n",
    "        target = target.to(device)\n",
    "        output = model(data)\n",
    "        L = cross_ent_loss(output.squeeze(1), target)\n",
    "        training_loss += L.item()\n",
    "        L.backward()\n",
    "        trainer.step()\n",
    "\n",
    "    # Validate the network without tracking gradients (no weight update)\n",
    "    with torch.no_grad():\n",
    "        for data, target in val_loader:\n",
    "            val_predictions = model(data.to(device))\n",
    "            L = cross_ent_loss(val_predictions.squeeze(1), target.to(device))\n",
    "            val_loss += L.item()\n",
    "\n",
    "    # Let's take the average losses\n",
    "    training_loss = training_loss / len(train_label)\n",
    "    val_loss = val_loss / len(val_label)\n",
    "\n",
    "    end = time.time()\n",
    "    print(\n",
    "        f\"Epoch {epoch}. Train_loss {training_loss}. Val_loss {val_loss}. Seconds {end-start}\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9. <a name=\"9\">Test the classifier on the validation data</a>\n",
    "(<a href=\"#0\">Go to top</a>)\n",
    "\n",
    "Let's get the validation predictions. During training, we already ran the model on the validation set with ```model(data.to(device))``` to compute the validation loss. Here we run it again and round each output to the nearest integer with ```np.rint``` to get a 0 or 1 class label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]\n"
     ]
    }
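   ],
   "source": [
    "val_predictions = []\n",
    "# Inference only, so gradient tracking is not needed\n",
    "with torch.no_grad():\n",
    "    for data, _ in val_loader:\n",
    "        val_preds = model(data.to(device))\n",
    "        # Round each output to the nearest integer to get a 0/1 label\n",
    "        val_predictions.extend(\n",
    "            [np.rint(val_pred)[0] for val_pred in val_preds.detach().cpu().numpy()]\n",
    "        )\n",
    "print(val_predictions[:10])"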
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The confusion matrix, classification report, and accuracy score for the validation set are printed below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[2175  475]\n",
      " [ 751 3598]]\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "         0.0       0.74      0.82      0.78      2650\n",
      "         1.0       0.88      0.83      0.85      4349\n",
      "\n",
      "    accuracy                           0.82      6999\n",
      "   macro avg       0.81      0.82      0.82      6999\n",
      "weighted avg       0.83      0.82      0.83      6999\n",
      "\n",
      "Accuracy (validation): 0.8248321188741249\n"
     ]
    }
   ],
   "source": [
    "# Compare the validation predictions with the true validation labels\n",
    "print(confusion_matrix(val_label, val_predictions))\n",
    "print(classification_report(val_label, val_predictions))\n",
    "print(\"Accuracy (validation):\", accuracy_score(val_label, val_predictions))"
   ]
  },
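  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The figures in the classification report can be recomputed by hand from the confusion matrix, which may help when reading it. The sketch below is a minimal illustration and not part of the original pipeline: it hard-codes the matrix printed above and assumes scikit-learn's layout (rows are true labels, columns are predicted labels). The variable names are chosen here for illustration only.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Confusion matrix copied from the validation output above.\n",
    "# Assumed layout (scikit-learn convention): rows = true class, columns = predicted class.\n",
    "cm = np.array([[2175, 475], [751, 3598]])\n",
    "tn, fp, fn, tp = cm.ravel()\n",
    "\n",
    "# Overall accuracy: correct predictions over all predictions\n",
    "accuracy = (tp + tn) / cm.sum()\n",
    "\n",
    "# Positive class (label 1)\n",
    "precision_pos = tp / (tp + fp)\n",
    "recall_pos = tp / (tp + fn)\n",
    "f1_pos = 2 * precision_pos * recall_pos / (precision_pos + recall_pos)\n",
    "\n",
    "# Negative class (label 0)\n",
    "precision_neg = tn / (tn + fn)\n",
    "recall_neg = tn / (tn + fp)\n",
    "\n",
    "print(f'accuracy: {accuracy:.4f}')\n",
    "print(f'class 1: precision {precision_pos:.2f}, recall {recall_pos:.2f}, f1 {f1_pos:.2f}')\n",
    "print(f'class 0: precision {precision_neg:.2f}, recall {recall_neg:.2f}')\n",
    "```\n",
    "\n",
    "Rounded to two decimals, these values should agree with the corresponding rows of the classification report."
   ]
  },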
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10. <a name=\"10\">Improvement ideas</a>\n",
    "(<a href=\"#0\">Go to top</a>)\n",
    "\n",
    "We can improve our model by:\n",
    "* Changing hyperparameters: learning rate, batch size, and hidden size\n",
    "* Increasing the number of layers: num_layers\n",
    "* Using more advanced architectures such as Gated Recurrent Units (GRU) and Long Short-Term Memory networks (LSTM)."
   ]
  }
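  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As one way to try the last idea, the sketch below uses an LSTM layer in a small classifier. It is a minimal illustration under stated assumptions, not the network built earlier in this notebook: the class name `LSTMClassifier`, its constructor arguments, and the toy sizes are hypothetical, and it assumes the model ends in a sigmoid so that its outputs can be rounded to 0 or 1 as in section 9.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "class LSTMClassifier(nn.Module):\n",
    "    # Hypothetical example model: embedding -> LSTM -> linear -> sigmoid\n",
    "    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)\n",
    "        self.lstm = nn.LSTM(\n",
    "            embed_size, hidden_size, num_layers=num_layers, batch_first=True\n",
    "        )\n",
    "        self.linear = nn.Linear(hidden_size, 1)\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        embedded = self.embedding(inputs)  # (batch, seq_len, embed_size)\n",
    "        _, (h_n, _) = self.lstm(embedded)  # h_n: (num_layers, batch, hidden_size)\n",
    "        # Classify from the final hidden state of the last layer\n",
    "        return torch.sigmoid(self.linear(h_n[-1]))  # (batch, 1)\n",
    "\n",
    "\n",
    "# Toy sizes for a quick shape check (assumed values, not the notebook's settings)\n",
    "model = LSTMClassifier(vocab_size=100, embed_size=8, hidden_size=16, num_layers=2)\n",
    "batch = torch.randint(0, 100, (4, 10))  # 4 sequences of 10 token indices\n",
    "probs = model(batch)\n",
    "print(probs.shape)\n",
    "```\n",
    "\n",
    "To try a GRU instead, `nn.LSTM` can be replaced with `nn.GRU`. A GRU returns only a hidden state, so the unpacking becomes `_, h_n = self.gru(embedded)`."
   ]
  }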
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sagemaker-distribution:Python",
   "language": "python",
   "name": "conda-env-sagemaker-distribution-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
